# import zipfile
# import os
# def unzip_file(zip_path, extract_to_folder):
# # Ensure the output folder exists
# os.makedirs(extract_to_folder, exist_ok=True)
# # Unzip the file
# with zipfile.ZipFile(zip_path, 'r') as zip_ref:
# zip_ref.extractall(extract_to_folder)
# print(f"Extracted '{zip_path}' to '{extract_to_folder}'")
# zip_file_path = "midis.zip"
# output_folder_path = "data"
# unzip_file(zip_file_path, output_folder_path)
# import os
# def count_files_in_folder(folder_path):
# count = sum(
# 1 for entry in os.scandir(folder_path) if entry.is_file()
# )
# print(f"Number of files in '{folder_path}': {count}")
# return count
# folder_path = "data/midis"
# count_files_in_folder(folder_path)
The dataset used for this project is GiantMIDI-Piano, a large-scale symbolic music dataset specifically tailored for classical piano music analysis and generation. It contains 10,855 MIDI files corresponding to 2,786 composers. The MIDI files are derived from live piano performances available on YouTube, which were then transcribed using a high-resolution, deep learning-based transcription system. This system is capable of capturing subtle musical elements such as dynamics, timing variations, and pedal usage - making the dataset particularly useful for expressive music generation tasks.
The MIDI files in the GiantMIDI-Piano dataset were not manually created, nor did we transcribe them ourselves. Instead, they were generated by the original dataset authors using a high-resolution automatic transcription pipeline. The source audio consisted of solo classical piano performances collected from YouTube, selected based on a curated metadata file that included composer names and piece titles. These audio recordings were then processed through a deep learning-based transcription system, designed to convert complex polyphonic piano audio into symbolic MIDI format. This system accurately detects note onsets, offsets, pitches, and velocities, capturing expressive performance elements such as timing and dynamics. The full transcription process, which spanned approximately 200 hours on a single GPU, produced over 10,000 MIDI files.
The feature-extraction step below extracts the following features from each MIDI file:
a. Filename
b. Number of notes
c. Average pitch
d. Min/max pitch
e. Total duration
f. Polyphony (estimated by note start overlap)
g. Composer (from filename)
import pretty_midi
import os
import pandas as pd
def extract_midi_features(file_path):
    """Summarize one MIDI file as a flat feature dict.

    Returns a dict with filename, composer (surname parsed from the
    comma-separated filename), duration, note count, pitch statistics,
    mean note duration, and an onset-density "polyphony" score
    (distinct note-start times per second). Returns None when the file
    has no non-drum notes or cannot be parsed.
    """
    try:
        midi_data = pretty_midi.PrettyMIDI(file_path)
        all_notes = [
            note
            for instrument in midi_data.instruments
            if not instrument.is_drum
            for note in instrument.notes
        ]
        if not all_notes:
            return None
        end_time = midi_data.get_end_time()
        pitches = [n.pitch for n in all_notes]
        onsets = {n.start for n in all_notes}
        note_lengths = [n.end - n.start for n in all_notes]
        # Distinct onsets per second; 0 guards against zero-length files.
        polyphony = len(onsets) / end_time if end_time > 0 else 0
        filename = os.path.basename(file_path)
        # GiantMIDI filenames start with "Surname, Firstname, ..." — take the surname.
        composer = filename.split(',')[0].strip() if ',' in filename else "Unknown"
        return {
            'filename': filename,
            'composer': composer,
            'duration_sec': end_time,
            'n_notes': len(all_notes),
            'avg_pitch': sum(pitches) / len(pitches),
            'min_pitch': min(pitches),
            'max_pitch': max(pitches),
            'avg_duration': sum(note_lengths) / len(note_lengths),
            'polyphony_score': polyphony
        }
    except Exception as e:
        print(f"Failed to process {file_path}: {e}")
        return None
# Walk the MIDI directory and assemble one feature row per parseable file.
midi_dir = "data/midis"
feature_list = []
for entry in os.listdir(midi_dir):
    if not entry.endswith('.mid'):
        continue
    row = extract_midi_features(os.path.join(midi_dir, entry))
    if row:
        feature_list.append(row)
midi_df = pd.DataFrame(feature_list)
midi_df.head()
| filename | composer | duration_sec | n_notes | avg_pitch | min_pitch | max_pitch | avg_duration | polyphony_score | |
|---|---|---|---|---|---|---|---|---|---|
| 0 | Kirchner, Theodor, 8 Romances, Op.22, m90Rf9AY... | Kirchner | 124.684896 | 1159 | 58.195858 | 24 | 91 | 0.655047 | 8.782138 |
| 1 | Jacobi, Karl, Introduction and Polonaise, Op.9... | Jacobi | 506.805990 | 10213 | 68.113287 | 24 | 94 | 0.208727 | 17.588585 |
| 2 | Chopin, Frédéric, Nocturne in C minor, B.1... | Chopin | 216.692708 | 860 | 65.243023 | 31 | 101 | 1.508937 | 3.908761 |
| 3 | Wieniawski, Józef, Valse-caprice, Op.46, 3BV... | Wieniawski | 502.740885 | 4006 | 65.908637 | 22 | 100 | 0.631758 | 7.270147 |
| 4 | Beethoven, Ludwig van, 12 Variations on the Ru... | Beethoven | 683.522135 | 6161 | 65.257101 | 26 | 90 | 0.359308 | 8.457663 |
print(midi_df.columns)
Index(['filename', 'composer', 'duration_sec', 'n_notes', 'avg_pitch',
'min_pitch', 'max_pitch', 'avg_duration', 'polyphony_score'],
dtype='object')
# Persist the extracted features, then reload them from disk.
csv_path = "midi_features_head.csv"
midi_df.to_csv(csv_path, index=False)
import pandas
midi_df = pandas.read_csv(csv_path)
midi_df.head()
| filename | composer | duration_sec | n_notes | avg_pitch | min_pitch | max_pitch | avg_duration | polyphony_score | |
|---|---|---|---|---|---|---|---|---|---|
| 0 | Kirchner, Theodor, 8 Romances, Op.22, m90Rf9AY... | Kirchner | 124.684896 | 1159 | 58.195858 | 24 | 91 | 0.655047 | 8.782138 |
| 1 | Jacobi, Karl, Introduction and Polonaise, Op.9... | Jacobi | 506.805990 | 10213 | 68.113287 | 24 | 94 | 0.208727 | 17.588585 |
| 2 | Chopin, Frédéric, Nocturne in C minor, B.1... | Chopin | 216.692708 | 860 | 65.243023 | 31 | 101 | 1.508937 | 3.908761 |
| 3 | Wieniawski, Józef, Valse-caprice, Op.46, 3BV... | Wieniawski | 502.740885 | 4006 | 65.908637 | 22 | 100 | 0.631758 | 7.270147 |
| 4 | Beethoven, Ludwig van, 12 Variations on the Ru... | Beethoven | 683.522135 | 6161 | 65.257101 | 26 | 90 | 0.359308 | 8.457663 |
import pandas
import pandas as pd
# The GiantMIDI metadata file is tab-separated; skip malformed rows.
metadata_df = pd.read_csv(
    'full_music_pieces_youtube_similarity_pianosoloprob_split.csv',
    delimiter='\t',
    quotechar='"',
    on_bad_lines='skip',
)
metadata_df.head()
| surname | firstname | music | nationality | birth | death | youtube_title | youtube_id | similarity | piano_solo_prob | audio_name | audio_duration | giant_midi_piano | split | surname_in_youtube_title | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | A. | Jag | Je t'aime Juliette | unknown | unknown | unknown | Je t'aime Juliette - A. Jag | OXC7Fd0ZN8o | 1.000000 | 6.848339e-01 | A., Jag, Je t'aime Juliette, OXC7Fd0ZN8o | 69.553469 | 1.0 | validation | 1.0 |
| 1 | Aadler | C. A. | Floating Islands | unknown | unknown | unknown | Mind-Boggling Off-Grid FLOATING Island HOMESTEAD | wPhWfjyqCBs | 0.333333 | NaN | NaN | NaN | NaN | NaN | NaN |
| 2 | Aagesen | Truid | Cantiones trium vocum | Danish | 1500 | 1600 | 2nd Edition, Motecta Trium Vocum | iWROE7EzwlE | 0.500000 | NaN | NaN | NaN | NaN | NaN | NaN |
| 3 | Aaron | Michael | Piano Course | unknown | unknown | unknown | Michael Aaron Piano Course Lessons Grade 1 Com... | V8WvKK-1b2c | 1.000000 | 7.859141e-01 | Aaron, Michael, Piano Course, V8WvKK-1b2c | 1556.569469 | 1.0 | validation | 1.0 |
| 4 | Aarons | Alfred E. | Brother Bill | unknown | unknown | unknown | Brother Bill | Giet2Krl6Ww | 0.666667 | 7.822375e-07 | Aarons, Alfred E., Brother Bill, Giet2Krl6Ww | 181.333469 | 0.0 | NaN | 0.0 |
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style="whitegrid")
plt.rcParams["figure.figsize"] = (10, 6)
# Distribution of piece durations.
plt.figure()
sns.histplot(midi_df['duration_sec'], bins=50, kde=True)
plt.title("Distribution of Piece Durations")
plt.xlabel("Duration (seconds)")
plt.ylabel("Number of Pieces")
plt.show()
# Distribution of note counts per piece.
plt.figure()
sns.histplot(midi_df['n_notes'], bins=50, kde=True, color='orange')
plt.title("Distribution of Number of Notes Per Piece")
plt.xlabel("Number of Notes")
plt.ylabel("Number of Pieces")
plt.show()
# Ten most frequent composers by piece count.
plt.figure()
top_composers = midi_df['composer'].value_counts().nlargest(10)
# Assign `hue` and disable the legend: passing `palette` without `hue` is
# deprecated in seaborn (FutureWarning; removal planned for v0.14).
sns.barplot(x=top_composers.values, y=top_composers.index,
            hue=top_composers.index, palette="mako", legend=False)
plt.title("Top 10 Most Frequent Composers")
plt.xlabel("Number of Pieces")
plt.ylabel("Composer")
plt.show()
/tmp/ipykernel_439/2009332229.py:4: FutureWarning: Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `y` variable to `hue` and set `legend=False` for the same effect. sns.barplot(x=top_composers.values, y=top_composers.index, palette="mako")
# Scatter plots: piece length vs. note count, and pitch vs. polyphony.
for xcol, ycol, title, xlab, ylab in [
    ('duration_sec', 'n_notes', "Duration vs. Number of Notes",
     "Duration (seconds)", "Number of Notes"),
    ('avg_pitch', 'polyphony_score', "Average Pitch vs. Polyphony Score",
     "Average Pitch (MIDI number)", "Polyphony Score"),
]:
    plt.figure()
    sns.scatterplot(data=midi_df, x=xcol, y=ycol, alpha=0.6)
    plt.title(title)
    plt.xlabel(xlab)
    plt.ylabel(ylab)
    plt.show()
# Strip the '.mid' extension to recover the metadata 'audio_name' join key.
# Anchored regex (r'\.mid$') so only the trailing extension is removed —
# a plain replace would also strip a '.mid' occurring mid-name.
midi_df['base_name'] = midi_df['filename'].str.replace(r'\.mid$', '', regex=True)
merged_df = midi_df.merge(metadata_df, left_on='base_name', right_on='audio_name', how='left')
# Keep only rows the dataset authors flagged as genuine GiantMIDI solo-piano
# pieces whose YouTube title actually mentions the composer.
filtered_df = merged_df[
    (merged_df['giant_midi_piano'] == 1) &
    (merged_df['piano_solo_prob'] > 0.5) &
    (merged_df['surname_in_youtube_title'] == 1)
]
top_nations = filtered_df['nationality'].value_counts().nlargest(10)
filtered_nations_df = filtered_df[filtered_df['nationality'].isin(top_nations.index)]
plt.figure(figsize=(10, 6))
sns.countplot(data=filtered_nations_df, y='nationality', order=top_nations.index)
plt.title("Top 10 Most Common Composer Nationalities")
plt.xlabel("Number of Pieces")
plt.ylabel("Nationality")
plt.show()
filtered_df = filtered_df.copy()
filtered_df['birth'] = pd.to_numeric(filtered_df['birth'], errors='coerce')
filtered_df['death'] = pd.to_numeric(filtered_df['death'], errors='coerce')
# Century of birth, e.g. birth 1810 -> 19. Going through the nullable Int64
# dtype fixes the float formatting of the original code, so axis labels read
# '18' / 'Unknown' instead of '18.0' / 'Unknown'.
century = filtered_df['birth'] // 100 + 1
filtered_df['century'] = century.astype('Int64').astype(str).replace('<NA>', 'Unknown')
plt.figure(figsize=(10, 6))
sns.countplot(data=filtered_df, x='century', order=sorted(filtered_df['century'].unique()))
plt.title("Distribution of Composer Centuries")
plt.xlabel("Century")
plt.ylabel("Number of Pieces")
plt.show()
# Sanity check: piano-solo probability of the kept rows (all should be > 0.5).
sns.histplot(filtered_df['piano_solo_prob'], bins=20)
<Axes: xlabel='piano_solo_prob', ylabel='Count'>
# Pairwise correlations between the numeric musical features.
corr_cols = ['duration_sec', 'n_notes', 'avg_pitch', 'polyphony_score']
plt.figure(figsize=(10, 6))
sns.heatmap(filtered_df[corr_cols].corr(), annot=True, cmap="coolwarm")
plt.title("Correlation Between Musical Features")
plt.show()
# Baseline RNN Model
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from miditoolkit import MidiFile
from glob import glob
import numpy as np
import random
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from miditok import REMI, TokenizerConfig
import math
from collections import Counter
from miditok import TokSequence
# Deterministic shuffle, small train/val/test split, and REMI tokenizer training.
random.seed(42)
main_dir = "data/midis"
all_midi_paths = []
for fname in os.listdir(main_dir):
    if fname.endswith(".mid"):
        all_midi_paths.append(os.path.join(main_dir, fname))
random.shuffle(all_midi_paths)
train_paths = all_midi_paths[:100]
val_paths = all_midi_paths[100:125]
test_paths = all_midi_paths[125:150]
config = TokenizerConfig(num_velocities=16, use_chords=False, use_programs=False)
tokenizer = REMI(config)
tokenizer.train(vocab_size=1000, files_paths=train_paths)
print("Tokenizer complete")
class MusicDataset(Dataset):
    """Fixed-length token-id windows cut from REMI-tokenized MIDI files.

    Each item is an (input, target) pair of `seq_len` token ids where the
    target is the input shifted one step, for next-token prediction. Files
    that fail to tokenize are skipped with an error message.
    """

    def __init__(self, midi_paths, tokenizer, seq_len=256):
        self.data = []
        self.token_counts = Counter()  # frequency of each token id seen
        self.total_tokens = 0
        self.seq_len = seq_len
        for path in midi_paths:
            try:
                # Pass the file path straight to the tokenizer: miditok >= 3.0
                # loads MIDI via symusic itself, which avoids the deprecated
                # miditoolkit.MidiFile round-trip (and its UserWarning).
                token_seq = tokenizer(path)[0].tokens
                token_ids = [tokenizer[token] for token in token_seq if token in tokenizer.vocab]
                self.token_counts.update(token_ids)
                self.total_tokens += len(token_ids)
                # Non-overlapping windows; the +1 lookahead provides targets.
                for i in range(0, len(token_ids) - seq_len, seq_len):
                    self.data.append((token_ids[i:i+seq_len], token_ids[i+1:i+seq_len+1]))
            except Exception as e:
                print(f"[ERROR] {path}: {e}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        x, y = self.data[idx]
        return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)
# Wrap each split in a DataLoader; drop ragged trailing batches.
train_dataset = MusicDataset(train_paths, tokenizer)
val_dataset = MusicDataset(val_paths, tokenizer)
test_dataset = MusicDataset(test_paths, tokenizer)
loader_kwargs = dict(batch_size=64, drop_last=True)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)
print("All dataloaders ready")
Tokenizer complete
/tmp/ipykernel_1630/3806623913.py:50: UserWarning: You are using a depreciated `miditoolkit.MidiFile` object. MidiTokis now (>v3.0.0) using symusic.Score as MIDI backend. Your file willbe converted on the fly, however please consider using symusic. token_seq = tokenizer(midi)[0].tokens
All dataloaders ready
class RNNModel(torch.nn.Module):
    """Vanilla token-level RNN language model: embed -> RNN -> vocab logits."""

    def __init__(self, vocab_size, embed_size=256, hidden_size=512, num_layers=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.RNN(embed_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, hidden=None):
        """Return (logits, hidden) for a (batch, seq) tensor of token ids."""
        embedded = self.embedding(x)
        rnn_out, hidden = self.rnn(embedded, hidden)
        logits = self.fc(rnn_out)
        return logits, hidden
import matplotlib.pyplot as plt
def plot_losses(epoch_losses):
    """Plot per-epoch (train_loss, val_loss) pairs on a single chart."""
    train_curve, val_curve = zip(*epoch_losses)
    xs = range(1, len(train_curve) + 1)
    plt.plot(xs, train_curve, label="Train Loss")
    plt.plot(xs, val_curve, label="Validation Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Training & Validation Loss")
    plt.legend()
    plt.grid(True)
    plt.show()
# Single global device shared by training, evaluation, and sampling below.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def train_model(model, train_loader, val_loader, epochs=10, lr=1e-3, ckpt_path="best_model_rnn.pt"):
    """Train with Adam + cross-entropy, checkpointing the best validation loss.

    Returns a list of (train_loss, val_loss) tuples, one per epoch.
    """
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    best_val_loss = float("inf")
    history = []
    for epoch in range(epochs):
        model.train()
        running = 0
        for x, y in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            x, y = x.to(device), y.to(device)
            logits, _ = model(x)
            # Flatten (batch, seq, vocab) -> (batch*seq, vocab) for CE loss.
            loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1))
            running += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        avg_train_loss = running / len(train_loader)
        val_loss = evaluate_loss(model, val_loader)
        print(f"Epoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {val_loss:.4f}")
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), ckpt_path)
            print(" Checkpointed best model.")
        history.append((avg_train_loss, val_loss))
    return history
def evaluate_loss(model, dataloader):
    """Mean cross-entropy of `model` over `dataloader` (no gradient updates)."""
    model.eval()
    criterion = nn.CrossEntropyLoss()
    accumulated = 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            logits, _ = model(inputs)
            batch_loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))
            accumulated += batch_loss.item()
    return accumulated / len(dataloader)
def generate(model, start_token_id, max_len=1000):
    """Autoregressively sample a token-id sequence of length `max_len`.

    Models with a `.rnn` attribute carry hidden state between steps;
    other models are called statelessly on the single last token.
    """
    model.eval()
    generated = [start_token_id]
    current = torch.tensor([[start_token_id]], device=device)
    hidden = None
    for _ in range(max_len - 1):
        if hasattr(model, "rnn"):
            logits, hidden = model(current, hidden)
        else:
            logits, hidden = model(current), None
        probs = torch.softmax(logits[0, -1], dim=-1)
        next_token = torch.multinomial(probs, 1).item()
        generated.append(next_token)
        current = torch.tensor([[next_token]], device=device)
    return generated
def save_generated(ids, out_file, tokenizer):
    """Decode token ids back into a MIDI file written to `out_file`."""
    if not hasattr(tokenizer, 'vocab_inv'):
        # Build (and cache on the tokenizer) the id -> token reverse map once.
        tokenizer.vocab_inv = {idx: tok for tok, idx in tokenizer.vocab.items()}
    token_names = [tokenizer.vocab_inv[i] for i in ids]
    midi = tokenizer.decode([TokSequence(tokens=token_names)])
    midi.dump_midi(out_file)
    print(f" Saved generated MIDI to: {out_file}")
vocab_size = len(tokenizer.vocab)
model = RNNModel(vocab_size)
# 20 epochs over the 100-file training split; the best-val-loss weights are
# written to best_model_rnn.pt inside train_model.
losses = train_model(model, train_loader, val_loader, epochs=20, lr=1e-3)
plot_losses(losses)
# Reload the best checkpoint before sampling so generation does not use the
# final-epoch weights (val loss rose slightly at the end of training).
model.load_state_dict(torch.load("best_model_rnn.pt"))
# "Bar_None" is the REMI bar marker — used here as the start-of-sequence token.
start_token = tokenizer.vocab["Bar_None"]
generated_ids = generate(model, start_token, max_len=256)
save_generated(generated_ids, "sample_output_test_rnn.mid", tokenizer)
Epoch 1/20: 100%|██████████| 59/59 [00:02<00:00, 24.34it/s]
Epoch 1 | Train Loss: 3.4076 | Val Loss: 3.0708 Checkpointed best model.
Epoch 2/20: 100%|██████████| 59/59 [00:02<00:00, 27.93it/s]
Epoch 2 | Train Loss: 3.0016 | Val Loss: 2.9523 Checkpointed best model.
Epoch 3/20: 100%|██████████| 59/59 [00:02<00:00, 27.92it/s]
Epoch 3 | Train Loss: 2.8642 | Val Loss: 2.8086 Checkpointed best model.
Epoch 4/20: 100%|██████████| 59/59 [00:02<00:00, 27.89it/s]
Epoch 4 | Train Loss: 2.7517 | Val Loss: 2.7418 Checkpointed best model.
Epoch 5/20: 100%|██████████| 59/59 [00:02<00:00, 27.97it/s]
Epoch 5 | Train Loss: 2.6997 | Val Loss: 2.7032 Checkpointed best model.
Epoch 6/20: 100%|██████████| 59/59 [00:02<00:00, 28.10it/s]
Epoch 6 | Train Loss: 2.6623 | Val Loss: 2.6779 Checkpointed best model.
Epoch 7/20: 100%|██████████| 59/59 [00:02<00:00, 28.12it/s]
Epoch 7 | Train Loss: 2.6249 | Val Loss: 2.6529 Checkpointed best model.
Epoch 8/20: 100%|██████████| 59/59 [00:02<00:00, 28.10it/s]
Epoch 8 | Train Loss: 2.6002 | Val Loss: 2.6391 Checkpointed best model.
Epoch 9/20: 100%|██████████| 59/59 [00:02<00:00, 28.09it/s]
Epoch 9 | Train Loss: 2.5716 | Val Loss: 2.6227 Checkpointed best model.
Epoch 10/20: 100%|██████████| 59/59 [00:02<00:00, 28.06it/s]
Epoch 10 | Train Loss: 2.5515 | Val Loss: 2.6138 Checkpointed best model.
Epoch 11/20: 100%|██████████| 59/59 [00:02<00:00, 28.06it/s]
Epoch 11 | Train Loss: 2.5282 | Val Loss: 2.5981 Checkpointed best model.
Epoch 12/20: 100%|██████████| 59/59 [00:02<00:00, 28.01it/s]
Epoch 12 | Train Loss: 2.5119 | Val Loss: 2.6032
Epoch 13/20: 100%|██████████| 59/59 [00:02<00:00, 27.90it/s]
Epoch 13 | Train Loss: 2.4995 | Val Loss: 2.5937 Checkpointed best model.
Epoch 14/20: 100%|██████████| 59/59 [00:02<00:00, 27.96it/s]
Epoch 14 | Train Loss: 2.4810 | Val Loss: 2.5918 Checkpointed best model.
Epoch 15/20: 100%|██████████| 59/59 [00:02<00:00, 27.94it/s]
Epoch 15 | Train Loss: 2.4663 | Val Loss: 2.5962
Epoch 16/20: 100%|██████████| 59/59 [00:02<00:00, 27.95it/s]
Epoch 16 | Train Loss: 2.4485 | Val Loss: 2.5818 Checkpointed best model.
Epoch 17/20: 100%|██████████| 59/59 [00:02<00:00, 27.99it/s]
Epoch 17 | Train Loss: 2.4371 | Val Loss: 2.5817 Checkpointed best model.
Epoch 18/20: 100%|██████████| 59/59 [00:02<00:00, 27.98it/s]
Epoch 18 | Train Loss: 2.4254 | Val Loss: 2.5843
Epoch 19/20: 100%|██████████| 59/59 [00:02<00:00, 27.98it/s]
Epoch 19 | Train Loss: 2.4067 | Val Loss: 2.5819
Epoch 20/20: 100%|██████████| 59/59 [00:02<00:00, 27.97it/s]
Epoch 20 | Train Loss: 2.3982 | Val Loss: 2.5915
Saved generated MIDI to: sample_output_test_rnn.mid
import pretty_midi
import numpy as np
from IPython.display import Audio
def midi_to_audio(midi_path, sample_rate=44100):
    """Render a MIDI file to an inline notebook Audio widget."""
    performance = pretty_midi.PrettyMIDI(midi_path)
    waveform = performance.synthesize(fs=sample_rate)
    return Audio(waveform, rate=sample_rate)
midi_to_audio("sample_output_test_rnn.mid")
import os
import glob
import torch
import pretty_midi
import numpy as np
import random
from torch import nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from IPython.display import Audio, display
class MIDIPreprocessor:
    """Turns MIDI files into integer note-index streams and training windows."""

    def __init__(self, sequence_length=100):
        self.sequence_length = sequence_length
        # MIDI pitches 21..109, inclusive on both ends (89 classes).
        self.pitch_range = (21, 109)
        self.num_pitches = self.pitch_range[1] - self.pitch_range[0] + 1
        print(f"Preprocessor initialized: pitch range {self.pitch_range}, sequence length {self.sequence_length}")

    def encode_midi(self, midi_path):
        """Return zero-based pitch indices for every in-range non-drum note."""
        try:
            midi = pretty_midi.PrettyMIDI(midi_path)
            low, high = self.pitch_range
            return [
                note.pitch - low
                for instrument in midi.instruments
                if not instrument.is_drum
                for note in instrument.notes
                if low <= note.pitch <= high
            ]
        except Exception as e:
            print(f"Error processing {midi_path}: {e}")
            return []

    def build_sequences(self, all_notes):
        """Slide a window of sequence_length + 1 notes over the note stream."""
        print("Building note sequences...")
        window = self.sequence_length + 1
        sequences = [
            all_notes[i:i + window]
            for i in range(len(all_notes) - self.sequence_length)
        ]
        print(f"Built {len(sequences)} sequences")
        return sequences
class MIDIDataset(Dataset):
    """Sliding-window note-sequence dataset pooled from many MIDI files.

    Raises ValueError when no MIDI files are found or no windows can be cut.
    """

    def __init__(self, midi_dir, preprocessor, file_limit=100):
        print(f"Loading MIDI dataset from: {midi_dir}")
        self.preprocessor = preprocessor
        midi_files = glob.glob(os.path.join(midi_dir, "*.mid"))[:file_limit]
        if len(midi_files) == 0:
            raise ValueError("No MIDI files found.")
        pooled = []
        for midi_file in midi_files:
            pooled += preprocessor.encode_midi(midi_file)
        self.notes = pooled
        self.sequences = self.preprocessor.build_sequences(self.notes)
        if len(self.sequences) == 0:
            raise ValueError("No sequences built.")
        print(f"Dataset ready: {len(self.sequences)} sequences")

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq = self.sequences[idx]
        return (
            torch.tensor(seq[:-1], dtype=torch.long),
            torch.tensor(seq[1:], dtype=torch.long),
        )
class Task1(nn.Module):
    """LSTM next-note language model over integer pitch classes."""

    def __init__(self, vocab_size, embed_dim=128, hidden_dim=256, num_layers=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        # (batch, seq) token ids -> (batch, seq, vocab) logits; state discarded.
        hidden_states, _ = self.lstm(self.embedding(x))
        return self.fc(hidden_states)
import torch
import matplotlib.pyplot as plt
from torch import nn
def train_model(model, train_loader, val_loader, preprocessor, epochs=10, lr=0.001, file_limit=100, base_path="symbolic_model"):
    """Train the note-level LSTM, checkpointing after every epoch.

    Saves `{base_path}_f{file_limit}_e{epoch}.pth` (weights plus the
    preprocessor's sequence_length/pitch_range) each epoch, and writes a
    train/val loss plot to loss_plot.png at the end.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    print(f"Starting training on {device}...")
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    model.train()
    train_losses = []
    val_losses = []

    def _validation_loss():
        # Average cross-entropy over the validation loader, no gradients.
        model.eval()
        running = 0
        with torch.no_grad():
            for x_val, y_val in val_loader:
                x_val = x_val.to(device)
                y_val = y_val.to(device)
                preds = model(x_val)
                running += criterion(preds.view(-1, preds.size(-1)), y_val.view(-1)).item()
        return running / len(val_loader)

    for epoch in range(epochs):
        print(f"\nEpoch {epoch + 1}/{epochs}")
        model.train()
        running_train = 0
        for batch_idx, (x, y) in enumerate(train_loader):
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            output = model(x)
            # Flatten (batch, seq, vocab) -> (batch*seq, vocab) for CE loss.
            loss = criterion(output.view(-1, output.size(-1)), y.view(-1))
            loss.backward()
            optimizer.step()
            running_train += loss.item()
            if batch_idx % 500 == 0:
                print(f" Batch {batch_idx}/{len(train_loader)}, Train Loss: {loss.item():.4f}")
        avg_train_loss = running_train / len(train_loader)
        train_losses.append(avg_train_loss)
        avg_val_loss = _validation_loss()
        val_losses.append(avg_val_loss)
        print(f"Epoch {epoch+1} complete. Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
        torch.save({
            "model_state_dict": model.state_dict(),
            "sequence_length": preprocessor.sequence_length,
            "pitch_range": preprocessor.pitch_range,
        }, f"{base_path}_f{file_limit}_e{epoch+1}.pth")
        print(f"Model saved to {base_path}_f{file_limit}_e{epoch+1}.pth")

    # Loss curves for the whole run.
    plt.figure(figsize=(8, 5))
    plt.plot(train_losses, label="Train Loss")
    plt.plot(val_losses, label="Val Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Training and Validation Loss")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig("loss_plot.png")
    plt.show()
    print("Loss plot saved as 'loss_plot.png'")
def load_model(model_path, vocab_size):
    """Restore a Task1 checkpoint; returns (model, sequence_length, pitch_range)."""
    map_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    checkpoint = torch.load(model_path, map_location=map_device)
    model = Task1(vocab_size)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    return model, checkpoint["sequence_length"], checkpoint["pitch_range"]
def generate_music(model, start_seq, length=200, temperature=1.0, seed=None):
    """Autoregressively sample `length` notes continuing `start_seq`.

    Temperature scales the logits before softmax sampling; a seed makes the
    sampling reproducible. Returns start_seq plus the sampled notes.
    """
    if seed is not None:
        torch.manual_seed(seed)
    model.eval()
    device = next(model.parameters()).device
    context = len(start_seq)
    generated = list(start_seq)
    window = torch.tensor(generated, dtype=torch.long).unsqueeze(0).to(device)
    with torch.no_grad():
        for _ in range(length):
            raw = model(window)[0]
            # Accept either (batch, seq, vocab) or (seq, vocab) model outputs.
            if raw.dim() == 3:
                logits = raw[0, -1, :]
            elif raw.dim() == 2:
                logits = raw[-1]
            else:
                raise ValueError("Unexpected model output shape.")
            probs = torch.softmax(logits / temperature, dim=0)
            next_note = torch.multinomial(probs, num_samples=1).item()
            generated.append(next_note)
            # Keep only the most recent `context` notes as the next input.
            window = torch.tensor(generated[-context:], dtype=torch.long).unsqueeze(0).to(device)
    return generated
def save_midi(note_sequence, file_path, preprocessor, tempo=120):
    """Write encoded notes as an evenly-spaced monophonic piano MIDI file.

    Each note lasts one beat at `tempo` BPM; pitches are shifted back up by
    the preprocessor's lower pitch bound.
    """
    step = 60.0 / tempo
    base_pitch = preprocessor.pitch_range[0]
    pm = pretty_midi.PrettyMIDI()
    piano = pretty_midi.Instrument(program=0)
    clock = 0.0
    for encoded in note_sequence:
        piano.notes.append(pretty_midi.Note(
            velocity=100,
            pitch=encoded + base_pitch,
            start=clock,
            end=clock + step,
        ))
        clock += step
    pm.instruments.append(piano)
    pm.write(file_path)
    print(f"MIDI saved: {file_path}")
def play_midi(file_path):
    """Synthesize a MIDI file and play it inline in the notebook."""
    waveform = pretty_midi.PrettyMIDI(file_path).synthesize()
    display(Audio(waveform, rate=44100))
from torch.utils.data import DataLoader, random_split
# 90/10 train/validation split over windows pooled from the first 100 files.
preprocessor = MIDIPreprocessor(sequence_length=50)
dataset = MIDIDataset("data/midis", preprocessor, file_limit=100)
val_size = int(0.1 * len(dataset))
train_size = len(dataset) - val_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
Preprocessor initialized: pitch range (21, 109), sequence length 50 Loading MIDI dataset from: data/midis Building note sequences... Built 336358 sequences Dataset ready: 336358 sequences
# Train the LSTM baseline; a checkpoint is written every epoch by train_model.
model = Task1(vocab_size=preprocessor.num_pitches)
train_model(model, train_loader, val_loader, preprocessor, epochs=20)
Starting training on cuda... Epoch 1/20 Batch 0/4731, Train Loss: 4.4866 Batch 500/4731, Train Loss: 3.1311 Batch 1000/4731, Train Loss: 2.8842 Batch 1500/4731, Train Loss: 2.8125 Batch 2000/4731, Train Loss: 2.6422 Batch 2500/4731, Train Loss: 2.4669 Batch 3000/4731, Train Loss: 2.3412 Batch 3500/4731, Train Loss: 2.0884 Batch 4000/4731, Train Loss: 2.0008 Batch 4500/4731, Train Loss: 1.9386 Epoch 1 complete. Train Loss: 2.5786, Val Loss: 2.0064 Model saved to symbolic_model_f100_e1.pth Epoch 2/20 Batch 0/4731, Train Loss: 1.9593 Batch 500/4731, Train Loss: 1.8270 Batch 1000/4731, Train Loss: 1.7251 Batch 1500/4731, Train Loss: 1.7805 Batch 2000/4731, Train Loss: 1.6123 Batch 2500/4731, Train Loss: 1.6750 Batch 3000/4731, Train Loss: 1.6325 Batch 3500/4731, Train Loss: 1.5674 Batch 4000/4731, Train Loss: 1.5423 Batch 4500/4731, Train Loss: 1.5208 Epoch 2 complete. Train Loss: 1.7081, Val Loss: 1.5302 Model saved to symbolic_model_f100_e2.pth Epoch 3/20 Batch 0/4731, Train Loss: 1.4920 Batch 500/4731, Train Loss: 1.4557 Batch 1000/4731, Train Loss: 1.4258 Batch 1500/4731, Train Loss: 1.4187 Batch 2000/4731, Train Loss: 1.3488 Batch 2500/4731, Train Loss: 1.3993 Batch 3000/4731, Train Loss: 1.3231 Batch 3500/4731, Train Loss: 1.4511 Batch 4000/4731, Train Loss: 1.3839 Batch 4500/4731, Train Loss: 1.2585 Epoch 3 complete. Train Loss: 1.3851, Val Loss: 1.3328 Model saved to symbolic_model_f100_e3.pth Epoch 4/20 Batch 0/4731, Train Loss: 1.3088 Batch 500/4731, Train Loss: 1.3380 Batch 1000/4731, Train Loss: 1.3364 Batch 1500/4731, Train Loss: 1.2599 Batch 2000/4731, Train Loss: 1.2737 Batch 2500/4731, Train Loss: 1.2240 Batch 3000/4731, Train Loss: 1.2105 Batch 3500/4731, Train Loss: 1.1997 Batch 4000/4731, Train Loss: 1.2086 Batch 4500/4731, Train Loss: 1.1723 Epoch 4 complete. 
Train Loss: 1.2356, Val Loss: 1.2218 Model saved to symbolic_model_f100_e4.pth Epoch 5/20 Batch 0/4731, Train Loss: 1.1673 Batch 500/4731, Train Loss: 1.2181 Batch 1000/4731, Train Loss: 1.1913 Batch 1500/4731, Train Loss: 1.1767 Batch 2000/4731, Train Loss: 1.2371 Batch 2500/4731, Train Loss: 1.2301 Batch 3000/4731, Train Loss: 1.0791 Batch 3500/4731, Train Loss: 1.1235 Batch 4000/4731, Train Loss: 1.1507 Batch 4500/4731, Train Loss: 1.0776 Epoch 5 complete. Train Loss: 1.1516, Val Loss: 1.1596 Model saved to symbolic_model_f100_e5.pth Epoch 6/20 Batch 0/4731, Train Loss: 1.0774 Batch 500/4731, Train Loss: 1.1313 Batch 1000/4731, Train Loss: 1.1265 Batch 1500/4731, Train Loss: 1.1412 Batch 2000/4731, Train Loss: 1.0861 Batch 2500/4731, Train Loss: 1.0883 Batch 3000/4731, Train Loss: 1.1273 Batch 3500/4731, Train Loss: 1.0519 Batch 4000/4731, Train Loss: 1.0771 Batch 4500/4731, Train Loss: 1.1455 Epoch 6 complete. Train Loss: 1.0974, Val Loss: 1.1170 Model saved to symbolic_model_f100_e6.pth Epoch 7/20 Batch 0/4731, Train Loss: 1.0939 Batch 500/4731, Train Loss: 1.1123 Batch 1000/4731, Train Loss: 1.0623 Batch 1500/4731, Train Loss: 1.0700 Batch 2000/4731, Train Loss: 1.0477 Batch 2500/4731, Train Loss: 1.1031 Batch 3000/4731, Train Loss: 0.9900 Batch 3500/4731, Train Loss: 1.0420 Batch 4000/4731, Train Loss: 1.0634 Batch 4500/4731, Train Loss: 1.0290 Epoch 7 complete. Train Loss: 1.0581, Val Loss: 1.0869 Model saved to symbolic_model_f100_e7.pth Epoch 8/20 Batch 0/4731, Train Loss: 1.0379 Batch 500/4731, Train Loss: 1.0057 Batch 1000/4731, Train Loss: 1.0283 Batch 1500/4731, Train Loss: 0.9872 Batch 2000/4731, Train Loss: 1.0183 Batch 2500/4731, Train Loss: 1.0183 Batch 3000/4731, Train Loss: 1.0018 Batch 3500/4731, Train Loss: 0.9894 Batch 4000/4731, Train Loss: 1.0257 Batch 4500/4731, Train Loss: 0.9834 Epoch 8 complete. 
Train Loss: 1.0287, Val Loss: 1.0570 Model saved to symbolic_model_f100_e8.pth Epoch 9/20 Batch 0/4731, Train Loss: 0.9979 Batch 500/4731, Train Loss: 0.9744 Batch 1000/4731, Train Loss: 0.9818 Batch 1500/4731, Train Loss: 1.0085 Batch 2000/4731, Train Loss: 1.0561 Batch 2500/4731, Train Loss: 0.9815 Batch 3000/4731, Train Loss: 1.0033 Batch 3500/4731, Train Loss: 0.9868 Batch 4000/4731, Train Loss: 0.9614 Batch 4500/4731, Train Loss: 1.0136 Epoch 9 complete. Train Loss: 1.0043, Val Loss: 1.0344 Model saved to symbolic_model_f100_e9.pth Epoch 10/20 Batch 0/4731, Train Loss: 0.9829 Batch 500/4731, Train Loss: 1.0075 Batch 1000/4731, Train Loss: 0.9666 Batch 1500/4731, Train Loss: 1.0069 Batch 2000/4731, Train Loss: 0.9614 Batch 2500/4731, Train Loss: 0.9906 Batch 3000/4731, Train Loss: 0.9751 Batch 3500/4731, Train Loss: 1.0184 Batch 4000/4731, Train Loss: 0.9644 Batch 4500/4731, Train Loss: 1.0216 Epoch 10 complete. Train Loss: 0.9842, Val Loss: 1.0179 Model saved to symbolic_model_f100_e10.pth Epoch 11/20 Batch 0/4731, Train Loss: 0.9614 Batch 500/4731, Train Loss: 0.9420 Batch 1000/4731, Train Loss: 0.9650 Batch 1500/4731, Train Loss: 0.9655 Batch 2000/4731, Train Loss: 0.9812 Batch 2500/4731, Train Loss: 0.9396 Batch 3000/4731, Train Loss: 0.9947 Batch 3500/4731, Train Loss: 1.0044 Batch 4000/4731, Train Loss: 0.9412 Batch 4500/4731, Train Loss: 0.9848 Epoch 11 complete. Train Loss: 0.9671, Val Loss: 1.0031 Model saved to symbolic_model_f100_e11.pth Epoch 12/20 Batch 0/4731, Train Loss: 0.9281 Batch 500/4731, Train Loss: 0.9982 Batch 1000/4731, Train Loss: 0.9070 Batch 1500/4731, Train Loss: 1.0118 Batch 2000/4731, Train Loss: 0.9661 Batch 2500/4731, Train Loss: 0.9760 Batch 3000/4731, Train Loss: 0.9840 Batch 3500/4731, Train Loss: 0.9229 Batch 4000/4731, Train Loss: 0.9543 Batch 4500/4731, Train Loss: 0.9869 Epoch 12 complete. 
Train Loss: 0.9520, Val Loss: 0.9873 Model saved to symbolic_model_f100_e12.pth Epoch 13/20 Batch 0/4731, Train Loss: 0.9392 Batch 500/4731, Train Loss: 0.8681 Batch 1000/4731, Train Loss: 0.9210 Batch 1500/4731, Train Loss: 0.9989 Batch 2000/4731, Train Loss: 1.0534 Batch 2500/4731, Train Loss: 0.8954 Batch 3000/4731, Train Loss: 0.9963 Batch 3500/4731, Train Loss: 0.9660 Batch 4000/4731, Train Loss: 0.9737 Batch 4500/4731, Train Loss: 0.9184 Epoch 13 complete. Train Loss: 0.9386, Val Loss: 0.9761 Model saved to symbolic_model_f100_e13.pth Epoch 14/20 Batch 0/4731, Train Loss: 0.8807 Batch 500/4731, Train Loss: 0.9447 Batch 1000/4731, Train Loss: 0.9628 Batch 1500/4731, Train Loss: 0.9209 Batch 2000/4731, Train Loss: 0.9491 Batch 2500/4731, Train Loss: 0.9531 Batch 3000/4731, Train Loss: 0.9424 Batch 3500/4731, Train Loss: 0.9481 Batch 4000/4731, Train Loss: 0.9102 Batch 4500/4731, Train Loss: 0.9232 Epoch 14 complete. Train Loss: 0.9268, Val Loss: 0.9665 Model saved to symbolic_model_f100_e14.pth Epoch 15/20 Batch 0/4731, Train Loss: 0.8969 Batch 500/4731, Train Loss: 0.9352 Batch 1000/4731, Train Loss: 0.8824 Batch 1500/4731, Train Loss: 0.8754 Batch 2000/4731, Train Loss: 0.8956 Batch 2500/4731, Train Loss: 0.9022 Batch 3000/4731, Train Loss: 0.8872 Batch 3500/4731, Train Loss: 0.9434 Batch 4000/4731, Train Loss: 0.8789 Batch 4500/4731, Train Loss: 0.8746 Epoch 15 complete. Train Loss: 0.9161, Val Loss: 0.9536 Model saved to symbolic_model_f100_e15.pth Epoch 16/20 Batch 0/4731, Train Loss: 0.8407 Batch 500/4731, Train Loss: 0.9048 Batch 1000/4731, Train Loss: 0.9198 Batch 1500/4731, Train Loss: 0.8823 Batch 2000/4731, Train Loss: 0.9031 Batch 2500/4731, Train Loss: 0.9427 Batch 3000/4731, Train Loss: 0.9086 Batch 3500/4731, Train Loss: 0.8719 Batch 4000/4731, Train Loss: 0.8657 Batch 4500/4731, Train Loss: 0.9103 Epoch 16 complete. 
Train Loss: 0.9061, Val Loss: 0.9461 Model saved to symbolic_model_f100_e16.pth Epoch 17/20 Batch 0/4731, Train Loss: 0.8703 Batch 500/4731, Train Loss: 0.9346 Batch 1000/4731, Train Loss: 0.8435 Batch 1500/4731, Train Loss: 0.8955 Batch 2000/4731, Train Loss: 0.8907 Batch 2500/4731, Train Loss: 0.8927 Batch 3000/4731, Train Loss: 0.9145 Batch 3500/4731, Train Loss: 0.9361 Batch 4000/4731, Train Loss: 0.8999 Batch 4500/4731, Train Loss: 0.9123 Epoch 17 complete. Train Loss: 0.8966, Val Loss: 0.9379 Model saved to symbolic_model_f100_e17.pth Epoch 18/20 Batch 0/4731, Train Loss: 0.8617 Batch 500/4731, Train Loss: 0.8669 Batch 1000/4731, Train Loss: 0.8392 Batch 1500/4731, Train Loss: 0.8924 Batch 2000/4731, Train Loss: 0.9068 Batch 2500/4731, Train Loss: 0.8603 Batch 3000/4731, Train Loss: 0.8543 Batch 3500/4731, Train Loss: 0.8754 Batch 4000/4731, Train Loss: 0.8908 Batch 4500/4731, Train Loss: 0.8644 Epoch 18 complete. Train Loss: 0.8885, Val Loss: 0.9289 Model saved to symbolic_model_f100_e18.pth Epoch 19/20 Batch 0/4731, Train Loss: 0.8545 Batch 500/4731, Train Loss: 0.8145 Batch 1000/4731, Train Loss: 0.9053 Batch 1500/4731, Train Loss: 0.9408 Batch 2000/4731, Train Loss: 0.8843 Batch 2500/4731, Train Loss: 0.8901 Batch 3000/4731, Train Loss: 0.9212 Batch 3500/4731, Train Loss: 0.9347 Batch 4000/4731, Train Loss: 0.8735 Batch 4500/4731, Train Loss: 0.8816 Epoch 19 complete. Train Loss: 0.8801, Val Loss: 0.9244 Model saved to symbolic_model_f100_e19.pth Epoch 20/20 Batch 0/4731, Train Loss: 0.9042 Batch 500/4731, Train Loss: 0.8567 Batch 1000/4731, Train Loss: 0.8567 Batch 1500/4731, Train Loss: 0.8742 Batch 2000/4731, Train Loss: 0.9002 Batch 2500/4731, Train Loss: 0.9080 Batch 3000/4731, Train Loss: 0.8216 Batch 3500/4731, Train Loss: 0.8231 Batch 4000/4731, Train Loss: 0.8374 Batch 4500/4731, Train Loss: 0.8503 Epoch 20 complete. Train Loss: 0.8727, Val Loss: 0.9160 Model saved to symbolic_model_f100_e20.pth
Loss plot saved as 'loss_plot.png'
def midi_generation(model_path, output_path="generated_random.mid", length=200, temperature=1.5):
    """Load a trained symbolic model from *model_path* and write a generated MIDI file.

    A random primer of ``preprocessor.sequence_length`` pitches is drawn
    uniformly from the vocabulary, then extended by ``generate_music`` and
    saved with ``save_midi``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    preprocessor = MIDIPreprocessor(sequence_length=50)
    vocab_size = preprocessor.num_pitches

    # Restore the trained weights from the checkpoint dictionary.
    checkpoint = torch.load(model_path, map_location=device)
    model = Task1(vocab_size)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device)
    model.eval()

    # Uniform random primer over the full pitch vocabulary.
    primer = [random.randint(0, vocab_size - 1) for _ in range(preprocessor.sequence_length)]
    notes = generate_music(model, primer, length=length, temperature=temperature, seed=None)
    save_midi(notes, output_path, preprocessor)

midi_generation("symbolic_model_f100_e20.pth", output_path="trial4.mid", length=300, temperature=1.0)
Preprocessor initialized: pitch range (21, 109), sequence length 50 MIDI saved: trial4.mid
# !pip install music21
from glob import glob
import random
import numpy as np
import pandas as pd
from numpy.random import choice
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from miditok import REMI, TokenizerConfig
from mido import Message, MidiFile, MidiTrack, MetaMessage, bpm2tempo
from music21 import midi, chord, note
# Select GPU when available, falling back to CPU so this cell also runs on
# machines without CUDA (other cells in this file already use this pattern).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Collect every transcribed GiantMIDI-Piano file under data/midis.
midi_files = glob('data/midis/*.mid')
len(midi_files)
10854
# Set the random seed so the 1500-file sample below is reproducible.
random.seed(42)
# NOTE(review): `dataroot` appears unused in this file, and its leading slash
# does not match the relative 'data/midis' path used everywhere else — confirm.
dataroot = '/data/midis'
sample_files = random.sample(midi_files, 1500)
print(sample_files[0])
data/midis/Ladurner, Ignace Antoine, 3 Keyboard Sonatas, Op.4, fzYLT5JMtZk.mid
def midi_preprocess(midi_file, max_len=30):
    """Extract a fixed-length melody/harmony pair from the first 30s of a MIDI file.

    Parameters
    ----------
    midi_file : str — path to a MIDI file readable by mido.
    max_len : int — fixed output length for both sequences.

    Returns
    -------
    melodies : list[int] — one MIDI pitch per onset group, zero-padded to max_len.
    harmonies : list[str] — music21 common chord name per group, padded with "Rest".
    """
    midi_obj = MidiFile(midi_file)  # renamed: don't shadow the music21 `midi` module
    note_times = {}
    melodies = []
    harmonies = []
    end_time = 30  # seconds of performance to scan
    current_time = 0
    for msg in midi_obj.play():
        current_time += msg.time
        if current_time >= end_time:
            break
        if msg.type == 'note_on' and msg.velocity > 0:
            # Group notes by ABSOLUTE onset time. The original keyed on
            # msg.time (the inter-message delta), which collapsed every
            # zero-delta note in the whole excerpt into one giant "chord".
            timestamp = current_time
            if timestamp not in note_times:
                note_times[timestamp] = []
            note_times[timestamp].append(msg.note)
    for timestamp, notes in note_times.items():
        if len(melodies) >= max_len:
            break
        melodies.append(notes[0])  # first note of the group stands in for the melody
        harmonies.append(chord.Chord(notes).commonName)
    # Pad both sequences to a fixed length for batching.
    if len(melodies) < max_len:
        melodies += [0] * (max_len - len(melodies))
        harmonies += ["Rest"] * (max_len - len(harmonies))
    return melodies, harmonies
melodies = []
harmonies = []
# Preprocess a small pilot batch of ten sampled files.
for idx in range(10):
    print(idx)
    mel, harm = midi_preprocess(sample_files[idx])
    melodies.append(mel)
    harmonies.append(harm)
0 1 2 3 4 5 6 7 8 9
# Build integer vocabularies for melody pitches and chord names.
note_set = sorted({pitch for mel in melodies for pitch in mel})
chord_set = sorted({name for harm in harmonies for name in harm})
note_to_int = {pitch: idx for idx, pitch in enumerate(note_set)}
chord_to_int = {name: idx for idx, name in enumerate(chord_set)}
# Encode every sequence with the vocabularies above.
X_train = np.array([[note_to_int[p] for p in mel] for mel in melodies])
y_train = np.array([[chord_to_int[c] for c in harm] for harm in harmonies])
print(X_train)
print(y_train)
[[34 32 15 31 29 19 27 27 22 19 15 29 22 31 32 34 36 22 39 15 20 20 15 38 36 24 20 36 24 34] [29 7 20 12 24 25 17 30 24 32 30 25 34 24 24 25 24 25 36 8 13 25 27 34 27 30 12 24 25 19] [39 6 12 19 6 5 12 38 15 12 19 41 22 12 31 7 21 33 38 14 33 36 34 26 34 5 19 24 2 26] [20 20 20 20 20 20 32 32 32 32 32 32 22 35 4 15 27 20 30 24 18 21 20 32 1 20 23 18 10 18] [ 3 8 11 22 13 16 18 16 18 22 25 28 30 18 34 37 34 40 32 16 8 13 30 43 25 34 18 25 42 37] [32 34 18 24 27 20 9 18 27 32 34 17 25 29 29 25 32 29 37 9 9 15 32 36 37 41 12 39 36 32] [24 7 29 22 6 24 7 6 36 24 31 38 32 36 31 32 26 27 34 22 29 36 32 31 5 29 40 28 26 38] [15 18 22 30 33 23 15 20 17 23 43 38 39 41 39 22 42 34 17 35 39 32 30 18 15 6 34 27 7 22] [32 20 32 20 32 20 32 32 20 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] [20 22 24 25 27 25 24 29 22 27 29 27 25 24 27 24 22 20 29 22 20 19 32 17 34 32 36 29 34 26]] [[20 64 48 64 17 64 18 4 45 64 64 21 2 62 11 28 36 70 63 65 74 64 37 46 45 64 64 64 19 66] [64 64 50 70 64 64 48 74 72 64 25 56 64 9 64 64 69 69 7 64 68 64 64 64 44 64 64 64 64 64] [64 30 1 75 64 50 1 64 64 31 67 64 64 69 64 64 64 64 64 64 64 64 64 64 64 64 64 64 64 64] [64 59 30 18 13 22 64 64 64 64 64 64 64 64 64 64 64 64 64 64 64 0 64 64 8 64 13 64 7 64] [64 64 0 64 42 64 64 25 72 61 64 3 17 44 46 64 61 50 64 64 64 64 64 64 50 64 64 64 64 46] [64 64 64 38 46 49 64 64 34 64 64 64 16 58 55 50 64 68 0 2 64 64 3 64 64 64 74 64 64 15] [64 35 43 64 41 32 62 1 64 39 50 64 64 64 34 26 64 64 64 17 64 50 29 64 64 52 64 4 64 64] [64 40 57 6 9 30 27 24 47 1 64 53 1 64 64 57 70 9 60 9 50 13 64 5 24 50 13 18 10 64] [64 64 64 20 64 64 20 64 64 23 23 23 23 23 23 23 23 23 23 23 23 23 23 23 23 23 23 23 23 23] [64 44 64 74 64 73 64 48 64 33 64 14 64 76 6 54 64 64 38 64 17 64 12 70 64 6 64 64 51 71]]
# Persist the encoded training data so preprocessing can be skipped later.
x_df = pd.DataFrame(X_train)
y_df = pd.DataFrame(y_train)
x_df.to_csv("training_x_test.csv", index=False)
y_df.to_csv("training_y_test.csv", index=False)
print("Single CSV file saved successfully!")
Single CSV file saved successfully!
# Reload the encoded data from disk and rebuild the numpy training arrays.
df_x = pd.read_csv("training_x_test.csv")
df_y = pd.read_csv("training_y_test.csv")
X_train = df_x.values
y_train = df_y.values
print("X_train shape:", X_train.shape)
print("y_train shape:", y_train.shape)
X_train shape: (10, 30) y_train shape: (10, 30)
class LSTM(nn.Module):
    """Melody-to-chord classifier: embedding -> LSTM -> two-layer head.

    forward() returns raw logits over chord classes for the final timestep.
    """

    def __init__(self, input_size, sequence_length, num_classes):
        super(LSTM, self).__init__()
        self.embedding = nn.Embedding(input_size, 64)
        self.lstm = nn.LSTM(64, 128, batch_first=True)
        self.fc1 = nn.Linear(128, 64)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(64, num_classes)

    def forward(self, x):
        x = self.embedding(x)
        x, _ = self.lstm(x)
        x = x[:, -1, :]  # last timestep summarizes the whole sequence
        x = self.fc1(x)
        x = self.relu(x)
        # Return logits: nn.CrossEntropyLoss (used as the criterion below)
        # applies log-softmax internally, so the original trailing nn.Softmax
        # double-normalized the outputs and crippled the gradients.
        return self.fc2(x)
# Instantiate the PyTorch classifier and its training objects.
model = LSTM(input_size=len(note_set), sequence_length=X_train.shape[1], num_classes=len(chord_set))
# NOTE(review): CrossEntropyLoss expects raw logits, but the LSTM above ends
# with nn.Softmax — verify there is no double-softmax before training with it.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Embedding

# Equivalent Keras model: predicts one chord class per melody timestep
# (return_sequences=True keeps the full sequence for the dense head).
model = Sequential()
model.add(Embedding(len(note_set), 64, input_length=X_train.shape[1]))
model.add(LSTM(128, return_sequences=True))
model.add(Dense(64, activation='relu'))
model.add(Dense(len(chord_set), activation='softmax'))
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
losses = model.fit(X_train, y_train, epochs=50, batch_size=32)
2025-06-03 18:11:00.382468: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered 2025-06-03 18:11:00.382527: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered 2025-06-03 18:11:00.383859: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered 2025-06-03 18:11:00.390558: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags. 2025-06-03 18:11:02.252067: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 2025-06-03 18:11:02.254867: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 2025-06-03 18:11:02.255419: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. 
See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 2025-06-03 18:11:02.262151: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 2025-06-03 18:11:02.262703: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 2025-06-03 18:11:02.263233: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 2025-06-03 18:11:02.384097: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 2025-06-03 18:11:02.384650: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. 
See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 2025-06-03 18:11:02.385168: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 2025-06-03 18:11:02.385634: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1929] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 10534 MB memory: -> device: 0, name: NVIDIA GeForce GTX 1080 Ti, pci bus id: 0000:c4:00.0, compute capability: 6.1
Epoch 1/50
2025-06-03 18:11:04.457748: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:454] Loaded cuDNN version 8902 2025-06-03 18:11:04.686499: I external/local_xla/xla/service/service.cc:168] XLA service 0x7f6ff8241f30 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices: 2025-06-03 18:11:04.686559: I external/local_xla/xla/service/service.cc:176] StreamExecutor device (0): NVIDIA GeForce GTX 1080 Ti, Compute Capability 6.1 2025-06-03 18:11:04.698810: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable. WARNING: All log messages before absl::InitializeLog() is called are written to STDERR I0000 00:00:1748974264.826799 689 device_compiler.h:186] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.
1/1 [==============================] - 3s 3s/step - loss: 4.3424 - accuracy: 0.0067 Epoch 2/50 1/1 [==============================] - 0s 8ms/step - loss: 4.3338 - accuracy: 0.3900 Epoch 3/50 1/1 [==============================] - 0s 7ms/step - loss: 4.3238 - accuracy: 0.5300 Epoch 4/50 1/1 [==============================] - 0s 9ms/step - loss: 4.3119 - accuracy: 0.5300 Epoch 5/50 1/1 [==============================] - 0s 8ms/step - loss: 4.2974 - accuracy: 0.5300 Epoch 6/50 1/1 [==============================] - 0s 9ms/step - loss: 4.2792 - accuracy: 0.5267 Epoch 7/50 1/1 [==============================] - 0s 9ms/step - loss: 4.2554 - accuracy: 0.5167 Epoch 8/50 1/1 [==============================] - 0s 9ms/step - loss: 4.2231 - accuracy: 0.4800 Epoch 9/50 1/1 [==============================] - 0s 9ms/step - loss: 4.1776 - accuracy: 0.4800 Epoch 10/50 1/1 [==============================] - 0s 8ms/step - loss: 4.1101 - accuracy: 0.4800 Epoch 11/50 1/1 [==============================] - 0s 10ms/step - loss: 4.0057 - accuracy: 0.4800 Epoch 12/50 1/1 [==============================] - 0s 9ms/step - loss: 3.8416 - accuracy: 0.4800 Epoch 13/50 1/1 [==============================] - 0s 8ms/step - loss: 3.6121 - accuracy: 0.4800 Epoch 14/50 1/1 [==============================] - 0s 8ms/step - loss: 3.4147 - accuracy: 0.4800 Epoch 15/50 1/1 [==============================] - 0s 8ms/step - loss: 3.3813 - accuracy: 0.4800 Epoch 16/50 1/1 [==============================] - 0s 8ms/step - loss: 3.4057 - accuracy: 0.4800 Epoch 17/50 1/1 [==============================] - 0s 8ms/step - loss: 3.3897 - accuracy: 0.4800 Epoch 18/50 1/1 [==============================] - 0s 9ms/step - loss: 3.3321 - accuracy: 0.4800 Epoch 19/50 1/1 [==============================] - 0s 8ms/step - loss: 3.2537 - accuracy: 0.4800 Epoch 20/50 1/1 [==============================] - 0s 8ms/step - loss: 3.1750 - accuracy: 0.4800 Epoch 21/50 1/1 [==============================] - 0s 9ms/step - loss: 3.1129 - 
accuracy: 0.4800 Epoch 22/50 1/1 [==============================] - 0s 9ms/step - loss: 3.0763 - accuracy: 0.4800 Epoch 23/50 1/1 [==============================] - 0s 8ms/step - loss: 3.0618 - accuracy: 0.4800 Epoch 24/50 1/1 [==============================] - 0s 8ms/step - loss: 3.0584 - accuracy: 0.4800 Epoch 25/50 1/1 [==============================] - 0s 9ms/step - loss: 3.0555 - accuracy: 0.4800 Epoch 26/50 1/1 [==============================] - 0s 8ms/step - loss: 3.0461 - accuracy: 0.4800 Epoch 27/50 1/1 [==============================] - 0s 9ms/step - loss: 3.0293 - accuracy: 0.4800 Epoch 28/50 1/1 [==============================] - 0s 9ms/step - loss: 3.0069 - accuracy: 0.4800 Epoch 29/50 1/1 [==============================] - 0s 8ms/step - loss: 2.9822 - accuracy: 0.4800 Epoch 30/50 1/1 [==============================] - 0s 8ms/step - loss: 2.9582 - accuracy: 0.4800 Epoch 31/50 1/1 [==============================] - 0s 9ms/step - loss: 2.9392 - accuracy: 0.4800 Epoch 32/50 1/1 [==============================] - 0s 9ms/step - loss: 2.9275 - accuracy: 0.4800 Epoch 33/50 1/1 [==============================] - 0s 8ms/step - loss: 2.9217 - accuracy: 0.4800 Epoch 34/50 1/1 [==============================] - 0s 8ms/step - loss: 2.9180 - accuracy: 0.4800 Epoch 35/50 1/1 [==============================] - 0s 8ms/step - loss: 2.9137 - accuracy: 0.4800 Epoch 36/50 1/1 [==============================] - 0s 8ms/step - loss: 2.9069 - accuracy: 0.4800 Epoch 37/50 1/1 [==============================] - 0s 8ms/step - loss: 2.8970 - accuracy: 0.4800 Epoch 38/50 1/1 [==============================] - 0s 9ms/step - loss: 2.8853 - accuracy: 0.4800 Epoch 39/50 1/1 [==============================] - 0s 9ms/step - loss: 2.8733 - accuracy: 0.4800 Epoch 40/50 1/1 [==============================] - 0s 9ms/step - loss: 2.8625 - accuracy: 0.4800 Epoch 41/50 1/1 [==============================] - 0s 8ms/step - loss: 2.8539 - accuracy: 0.4800 Epoch 42/50 1/1 
[==============================] - 0s 8ms/step - loss: 2.8474 - accuracy: 0.4800 Epoch 43/50 1/1 [==============================] - 0s 8ms/step - loss: 2.8416 - accuracy: 0.4800 Epoch 44/50 1/1 [==============================] - 0s 8ms/step - loss: 2.8353 - accuracy: 0.4800 Epoch 45/50 1/1 [==============================] - 0s 8ms/step - loss: 2.8275 - accuracy: 0.4800 Epoch 46/50 1/1 [==============================] - 0s 8ms/step - loss: 2.8180 - accuracy: 0.4800 Epoch 47/50 1/1 [==============================] - 0s 8ms/step - loss: 2.8073 - accuracy: 0.4800 Epoch 48/50 1/1 [==============================] - 0s 9ms/step - loss: 2.7963 - accuracy: 0.4800 Epoch 49/50 1/1 [==============================] - 0s 8ms/step - loss: 2.7859 - accuracy: 0.4800 Epoch 50/50 1/1 [==============================] - 0s 8ms/step - loss: 2.7757 - accuracy: 0.4800
# Plot the Keras training curve. The x-axis is derived from the history
# length instead of the hard-coded range(50), so the plot stays correct
# if the epoch count above ever changes.
loss = losses.history['loss']
plt.plot(range(1, len(loss) + 1), loss)
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training Loss vs. Epochs')
plt.show()
def get_melody(midi_file):
    """Return every note-on pitch from the first 30 seconds of *midi_file*."""
    song = MidiFile(midi_file)
    pitches = []
    limit = 30  # seconds of playback to scan
    elapsed = 0
    for msg in song.play():
        elapsed += msg.time
        if elapsed >= limit:
            break
        if msg.type == 'note_on' and msg.velocity > 0:
            pitches.append(msg.note)
    return pitches
# Test file: pull one sampled piece and truncate its melody to the model's
# 30-step input length.
midi_file = sample_files[1001]
melody = get_melody(midi_file)[:30]
print(midi_file)
print(melody)
data/midis/Joplin, Scott, A Breeze from Alabama, nXe43xnOEf4.mid [60, 55, 64, 67, 48, 60, 72, 63, 60, 57, 54, 57, 69, 68, 56, 69, 57, 71, 59, 72, 55, 60, 64, 55, 67, 52, 64, 63, 51, 64]
def generate_harmony(melody):
    """Predict one chord name per melody note with the trained Keras model.

    Pitches unseen during training fall back to vocabulary index 0.
    """
    encoded = [note_to_int[pitch] if pitch in note_to_int else 0 for pitch in melody]
    batch = np.array([encoded])
    prediction = model.predict(batch)
    # argmax per timestep, mapped back through the sorted chord vocabulary.
    return [chord_set[np.argmax(step)] for step in prediction[0]]

predicted_harmony = generate_harmony(melody)
print(predicted_harmony)
1/1 [==============================] - 0s 381ms/step ['note', 'note', 'note', 'note', 'note', 'note', 'note', 'note', 'note', 'note', 'note', 'note', 'note', 'note', 'note', 'note', 'note', 'note', 'note', 'note', 'note', 'note', 'note', 'note', 'note', 'note', 'note', 'note', 'note', 'note']
def harmony_midi(harmony):
    """Map predicted chord/note names to one representative MIDI pitch each.

    Tries to parse each name as a music21 chord (taking its lowest pitch),
    then as a single note; anything unparseable falls back to middle C (60).
    """
    midi_notes = []
    for h in harmony:
        try:
            midi_notes.append(chord.Chord(h).pitches[0].midi)
        except Exception:  # was a bare except: don't swallow KeyboardInterrupt/SystemExit
            try:
                midi_notes.append(note.Note(h).pitch.midi)
            except Exception:
                midi_notes.append(60)  # middle C fallback
    return midi_notes

midi_harmony = harmony_midi(predicted_harmony)
print(midi_harmony)
[60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60]
def generate_midi(melody, harmony, output_file="generated_harmony.mid"):
    """Write melody and harmony pitch lists to a two-track MIDI file.

    Each pitch becomes a note_on/note_off pair at fixed velocity 64 with a
    200-tick delta, so both tracks play in lockstep.
    """
    out = MidiFile()
    melody_track = MidiTrack()
    harmony_track = MidiTrack()
    for track, pitches in ((melody_track, melody), (harmony_track, harmony)):
        for p in pitches:
            track.append(Message('note_on', note=p, velocity=64, time=200))
            track.append(Message('note_off', note=p, velocity=64, time=200))
    out.tracks.append(melody_track)
    out.tracks.append(harmony_track)
    out.save(output_file)
    print(f"Saved generated harmony to {output_file}")

generate_midi(melody, midi_harmony)
Saved generated harmony to generated_harmony.mid
def play_midi(path):
    """Load a MIDI file with music21 and render it in the notebook player."""
    midi_file = midi.MidiFile()
    midi_file.open(path)
    midi_file.read()
    midi_file.close()
    score = midi.translate.midiFileToStream(midi_file)
    score.show('midi')

play_midi("generated_harmony.mid")
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from music21 import converter, note, chord, stream, roman
class Config:
    # Hyperparameters for the conditional MusicVAE defined below.
    max_seq_len = 512      # token length for both melody (condition) and harmony (target)
    latent_dim = 256       # VAE latent vector size
    d_model = 384          # transformer embedding width
    n_heads = 6            # attention heads per layer
    n_layers = 4           # depth of both encoder and decoder stacks
    dropout = 0.1
    batch_size = 4
    lr = 1e-4              # Adam learning rate
    grad_clip = 1.0        # max gradient norm during training
    kl_anneal_epochs = 20  # epochs to ramp the KL weight from 0 to 1
    eps = 1e-8             # numerical floor added to the latent std
    temp = 1.2             # softmax temperature at generation time
    top_k = 20             # top-k sampling cutoff at generation time
def build_vocab():
    """Build the token vocabulary: PAD, BOS, 88 pitches x 4 durations, then Roman numerals.

    IDs are assigned contiguously from 2 in a fixed order (pitch-major over
    durations, Roman figures last), so the mapping is fully deterministic.
    """
    vocab = {"PAD": 0, "BOS": 1}
    durations = [0.25, 0.5, 1.0, 2.0]
    figures = ["I", "ii", "iii", "IV", "V", "vi", "vii", "V7", "ii7", "I6", "V/V"]
    tokens = [f"Note_{p}_{d}" for p in range(21, 109) for d in durations]
    tokens += [f"Roman_{fig}" for fig in figures]
    for token_id, tok in enumerate(tokens, start=2):
        vocab[tok] = token_id
    return vocab
vocab = build_vocab()
inv_vocab = {v: k for k, v in vocab.items()}  # id -> token, for decoding generated sequences
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class Music21ConditionalDataset(Dataset):
    """Dataset of (melody, harmony) token sequences parsed with music21.

    Melody tokens come from the first part of each score; harmony tokens are
    Roman numerals of the chordified score, analyzed against the detected key.
    Both sequences are BOS-prefixed and PAD-padded to *max_seq_len*.
    """

    def __init__(self, midi_dir, max_seq_len=512):
        # Only the first 10 files are used (pilot-scale training).
        self.paths = list(Path(midi_dir).glob("*.mid"))[:10]
        self.max_seq_len = max_seq_len
        self.pad_id = vocab["PAD"]
        self.bos_id = vocab["BOS"]

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        try:
            score = converter.parse(str(self.paths[idx]))
            melody, harmony = self.extract_tokens(score)
            melody = [self.bos_id] + melody[:self.max_seq_len - 1]
            harmony = [self.bos_id] + harmony[:self.max_seq_len - 1]
            melody += [self.pad_id] * (self.max_seq_len - len(melody))
            harmony += [self.pad_id] * (self.max_seq_len - len(harmony))
            return torch.LongTensor(melody), torch.LongTensor(harmony)
        except Exception:
            # Unparseable file: return an all-PAD pair rather than aborting
            # the whole epoch. (Was a bare except.)
            dummy = [self.bos_id] + [self.pad_id] * (self.max_seq_len - 1)
            return torch.LongTensor(dummy), torch.LongTensor(dummy)

    def extract_tokens(self, score):
        """Return (melody_tokens, chord_tokens) for a parsed music21 score."""
        melody_tokens, chord_tokens = [], []
        melody_part = score.parts[0]  # Assume melody is in first part
        key = score.analyze('key')
        for el in melody_part.recurse().notesAndRests:
            dur = round(el.duration.quarterLength, 2)
            # Snap to the nearest duration in the vocabulary grid.
            dur = min([0.25, 0.5, 1.0, 2.0], key=lambda x: abs(x - dur))
            if isinstance(el, note.Note):
                pitch = el.pitch.midi
                tok = f"Note_{pitch}_{dur}"
                melody_tokens.append(vocab.get(tok, self.pad_id))
            else:
                melody_tokens.append(self.pad_id)  # rests/chords map to PAD
        harmony = score.chordify()
        # .flatten() replaces the deprecated .flat (music21 was emitting a
        # Music21DeprecationWarning for this call).
        for el in harmony.flatten().getElementsByClass(chord.Chord):
            try:
                rn = roman.romanNumeralFromChord(el, key)
                tok = f"Roman_{rn.figure}"
                chord_tokens.append(vocab.get(tok, self.pad_id))
            except Exception:
                chord_tokens.append(self.pad_id)  # unanalyzable chord
        return melody_tokens, chord_tokens
class PositionalEncoding(nn.Module):
    """Additive sinusoidal positional encoding (Vaswani et al. formulation)."""

    def __init__(self, d_model, max_len=512):
        super().__init__()
        positions = torch.arange(0, max_len).unsqueeze(1)
        freqs = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * freqs)  # even dims: sine
        table[:, 1::2] = torch.cos(positions * freqs)  # odd dims: cosine
        # Buffer: moves with .to(device), saved in state_dict, no gradients.
        self.register_buffer('pe', table.unsqueeze(0))

    def forward(self, x):
        # Add the encodings for the first seq_len positions to the batch.
        return x + self.pe[:, :x.size(1)]
class MusicVAE(nn.Module):
    """Conditional transformer VAE that maps a melody to a harmony sequence.

    encode(): melody -> (mu, logvar) via mean-pooled transformer features.
    decode(): re-encodes the melody as decoder memory, adds a projected
    latent z to every memory position, then runs a causally-masked
    transformer decoder over the (shifted) target harmony tokens.
    """

    def __init__(self, vocab_size):
        super().__init__()
        # Shared token embedding for condition and target sequences.
        self.embed = nn.Embedding(vocab_size, Config.d_model)
        nn.init.normal_(self.embed.weight, mean=0.0, std=0.02)
        self.pos_enc = PositionalEncoding(Config.d_model)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(Config.d_model, Config.n_heads, Config.d_model * 2, Config.dropout, batch_first=True, activation='gelu'),
            num_layers=Config.n_layers
        )
        self._init_weights(self.encoder)
        # Heads producing the variational posterior parameters.
        self.fc_mu = nn.Linear(Config.d_model, Config.latent_dim)
        self.fc_logvar = nn.Linear(Config.d_model, Config.latent_dim)
        nn.init.xavier_normal_(self.fc_mu.weight)
        nn.init.xavier_normal_(self.fc_logvar.weight)
        # Projects z back to model width so it can be added to the memory.
        self.latent_proj = nn.Linear(Config.latent_dim, Config.d_model)
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(Config.d_model, Config.n_heads, Config.d_model * 2, Config.dropout, batch_first=True, activation='gelu'),
            num_layers=Config.n_layers
        )
        self._init_weights(self.decoder)
        self.fc_out = nn.Linear(Config.d_model, vocab_size)
        nn.init.zeros_(self.fc_out.bias)

    def _init_weights(self, module):
        # Xavier-init every matrix-shaped parameter; biases/norms keep defaults.
        for p in module.parameters():
            if p.dim() > 1:
                nn.init.xavier_normal_(p)

    def encode(self, melody):
        """Return (mu, logvar) of the latent posterior for a melody batch."""
        src = self.embed(melody) * math.sqrt(Config.d_model)
        src = self.pos_enc(src)
        memory = self.encoder(src)
        # Mean-pool over the time dimension before the variational heads.
        mu = self.fc_mu(memory.mean(1))
        logvar = self.fc_logvar(memory.mean(1))
        return mu, logvar

    def reparameterize(self, mu, logvar):
        # Standard reparameterization trick; Config.eps keeps std strictly > 0.
        std = torch.exp(0.5 * logvar) + Config.eps
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, cond, target):
        """Decode harmony logits given latent z, melody condition, shifted target."""
        memory = self.encoder(self.pos_enc(self.embed(cond) * math.sqrt(Config.d_model)))
        # Broadcast the projected latent across every memory timestep.
        z_context = self.latent_proj(z).unsqueeze(1).repeat(1, memory.size(1), 1)
        memory = memory + z_context
        tgt = self.embed(target) * math.sqrt(Config.d_model)
        tgt = self.pos_enc(tgt)
        # Causal mask: each target position may only attend to earlier ones.
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(target.size(1), device=target.device)
        out = self.decoder(tgt=tgt, memory=memory, tgt_mask=tgt_mask)
        return self.fc_out(out)

    def forward(self, cond, target):
        """Teacher-forced pass. Returns (logits, mu, logvar)."""
        mu, logvar = self.encode(cond)
        z = self.reparameterize(mu, logvar)
        # target[:, :-1] right-shifts the harmony for next-token prediction.
        logits = self.decode(z, cond, target[:, :-1])
        return logits, mu, logvar
losses = []  # per-epoch mean losses, consumed by the plotting cell below

def train_model(midi_dir, epochs=50):
    """Train MusicVAE on *midi_dir* with KL annealing; saves weights when done."""
    dataset = Music21ConditionalDataset(midi_dir)
    loader = DataLoader(dataset, batch_size=Config.batch_size, shuffle=False)
    model = MusicVAE(len(vocab)).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=Config.lr)
    pad_id = vocab["PAD"]
    for epoch in range(epochs):
        model.train()
        running = 0
        # Linearly ramp the KL weight over the first kl_anneal_epochs epochs.
        kl_weight = min(1.0, epoch / Config.kl_anneal_epochs)
        for melody, harmony in loader:
            melody = melody.to(device)
            harmony = harmony.to(device)
            logits, mu, logvar = model(melody, harmony)
            # Reconstruction loss vs the shifted target, ignoring padding.
            rec = F.cross_entropy(logits.view(-1, logits.size(-1)), harmony[:, 1:].contiguous().view(-1), ignore_index=pad_id)
            # KL divergence of N(mu, sigma) from the standard normal prior.
            kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
            batch_loss = rec + kl_weight * kl
            opt.zero_grad()
            batch_loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), Config.grad_clip)
            opt.step()
            running += batch_loss.item()
        print(f"Epoch {epoch+1}, Loss: {running/len(loader):.4f}")
        losses.append(running/len(loader))
    torch.save(model.state_dict(), "musicvae_music21.pth")
def generate(model_path, melody_path, out_path="output.mid", max_tokens=256):
    """Generate a harmony for *melody_path* with a trained MusicVAE and write MIDI.

    Encodes the melody to a latent z, autoregressively samples Roman-numeral
    tokens with temperature + top-k sampling, then reconstructs a two-part
    music21 score (melody + chords, in a placeholder key of C) at *out_path*.
    """
    model = MusicVAE(len(vocab)).to(device)
    # map_location lets a GPU-trained checkpoint load on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    # Empty dir: the dataset instance is used only for extract_tokens.
    dataset = Music21ConditionalDataset("")
    melody_tokens, _ = dataset.extract_tokens(converter.parse(melody_path))
    melody = [vocab["BOS"]] + melody_tokens[:Config.max_seq_len - 1]
    melody += [vocab["PAD"]] * (Config.max_seq_len - len(melody))
    melody = torch.LongTensor(melody).unsqueeze(0).to(device)
    with torch.no_grad():
        mu, logvar = model.encode(melody)
        z = model.reparameterize(mu, logvar)
        generated = [vocab["BOS"]]
        for _ in range(max_tokens):
            # Keep the decoder input within the positional-encoding window.
            inp = torch.LongTensor([generated[-Config.max_seq_len + 1:]]).to(device)
            out = model.decode(z, melody, inp)[0, -1]
            probs = F.softmax(out / Config.temp, dim=0)
            topk_probs, topk_idx = probs.topk(Config.top_k)
            next_token = topk_idx[torch.multinomial(topk_probs, 1)].item()
            if next_token == vocab["PAD"]:
                break  # PAD acts as an end-of-sequence marker
            generated.append(next_token)
    # Reconstruct score with both melody and chords
    melody_stream = stream.Part()
    for tok_id in melody.squeeze().tolist():
        tok = inv_vocab.get(tok_id, "PAD")
        if tok.startswith("Note_"):
            _, pitch, dur = tok.split("_")
            n = note.Note(int(pitch), quarterLength=float(dur))
            melody_stream.append(n)
    harmony_stream = stream.Part()
    for tok_id in generated:
        tok = inv_vocab.get(tok_id, "PAD")
        if tok.startswith("Roman_"):
            fig = tok.replace("Roman_", "")
            try:
                rn = roman.RomanNumeral(fig, "C")  # placeholder key
                chord_obj = chord.Chord([str(p) for p in rn.pitches])
                chord_obj.quarterLength = 1.0
                harmony_stream.append(chord_obj)
            except Exception:  # skip figures music21 cannot realize (was bare except)
                pass
    full_score = stream.Score()
    full_score.insert(0, melody_stream)
    full_score.insert(0, harmony_stream)
    full_score.write("midi", fp=out_path)
train_model("data/midis")
/home/vsinha/.local/lib/python3.11/site-packages/music21/stream/base.py:3675: Music21DeprecationWarning: .flat is deprecated. Call .flatten() instead return self.iter().getElementsByClass(classFilterList)
Epoch 1, Loss: 4.9969 Epoch 2, Loss: 3.6185 Epoch 3, Loss: 3.2209 Epoch 4, Loss: 3.0361 Epoch 5, Loss: 2.8989 Epoch 6, Loss: 2.7716 Epoch 7, Loss: 2.6814 Epoch 8, Loss: 2.6512 Epoch 9, Loss: 2.5464 Epoch 10, Loss: 2.5089 Epoch 11, Loss: 2.4878 Epoch 12, Loss: 2.3944 Epoch 13, Loss: 2.3354 Epoch 14, Loss: 2.3007 Epoch 15, Loss: 2.2703 Epoch 16, Loss: 2.1886 Epoch 17, Loss: 2.1908 Epoch 18, Loss: 2.1190 Epoch 19, Loss: 2.0822 Epoch 20, Loss: 2.0290 Epoch 21, Loss: 2.0261 Epoch 22, Loss: 2.0352 Epoch 23, Loss: 1.9049 Epoch 24, Loss: 1.8353 Epoch 25, Loss: 1.8090 Epoch 26, Loss: 1.7589 Epoch 27, Loss: 1.7259 Epoch 28, Loss: 1.6826 Epoch 29, Loss: 1.6491 Epoch 30, Loss: 1.6276 Epoch 31, Loss: 1.6025 Epoch 32, Loss: 1.5488 Epoch 33, Loss: 1.5583 Epoch 34, Loss: 1.5575 Epoch 35, Loss: 1.4880 Epoch 36, Loss: 1.4543 Epoch 37, Loss: 1.4462 Epoch 38, Loss: 1.4125 Epoch 39, Loss: 1.3815 Epoch 40, Loss: 1.3378 Epoch 41, Loss: 1.3246 Epoch 42, Loss: 1.3538 Epoch 43, Loss: 1.2830 Epoch 44, Loss: 1.3144 Epoch 45, Loss: 1.2397 Epoch 46, Loss: 1.3221 Epoch 47, Loss: 1.2183 Epoch 48, Loss: 1.2184 Epoch 49, Loss: 1.1903 Epoch 50, Loss: 1.1305
import matplotlib.pyplot as plt

# Plot the per-epoch VAE training curve collected by train_model.
loss_values = losses
epochs = list(range(1, len(loss_values) + 1))  # epochs are 1-indexed on the axis
plt.figure(figsize=(10, 5))
plt.plot(epochs, loss_values, linestyle='-', color='royalblue', label='Training Loss')
plt.title('Training Loss vs Epochs')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid(True, linestyle='--', alpha=0.5)
plt.legend()
plt.tight_layout()
plt.show()
generate("musicvae_music21.pth", "data/midis/Violette, Andrew, 3 Little Pieces, O2KofziUWQU.mid")
import pretty_midi
import numpy as np
from IPython.display import Audio

def midi_to_audio(midi_path, sample_rate=44100):
    """Synthesize *midi_path* to a waveform and wrap it in a notebook Audio widget."""
    performance = pretty_midi.PrettyMIDI(midi_path)
    waveform = performance.synthesize(fs=sample_rate)
    return Audio(waveform, rate=sample_rate)

midi_to_audio("output.mid")